import numpy as np
from scipy.spatial import distance_matrix
from autode.geom import get_rot_mat_euler


def strip_identical_and_inv_axes(axes, sim_axis_tol):
    """
    Remove axes that duplicate an earlier axis, either directly or as its
    inverse (an axis and its negation describe the same rotation line).

    The first occurrence of each direction is kept and input order is
    preserved.

    ---------------------------------------------------------------------------
    Arguments:
        axes: sequence of axis vectors (np.ndarray)

        sim_axis_tol: tolerance on the Euclidean distance between two axis
                      vectors (or between one and the negation of the other)
                      below which they are considered the same

    Returns:
        (list(np.ndarray)): the retained axes
    """
    kept = []

    for axis in axes:
        # |a - u| catches a repeated axis, |a + u| (== |-a - u|) an inverted one
        is_duplicate = any(
            np.linalg.norm(axis - ref) < sim_axis_tol
            or np.linalg.norm(axis + ref) < sim_axis_tol
            for ref in kept
        )
        if not is_duplicate:
            kept.append(axis)

    return kept


def get_possible_axes(coords, max_triple_dist=2.0, sim_axis_tol=0.1):
    r"""
    Possible rotation axes in a molecule. Currently limited to inter-atomic
    vectors, average vectors and cross products i.e.::

          Y          Y --->
         / \        / \
        X   Y      X   Z

          |
          |
          ,

    For a collinear triple the cross product vanishes. In that case a direction
    perpendicular to the shared line is proposed instead, because that is where
    e.g. the C2 axes of a D∞h fragment lie.

    Candidate vectors of (numerically) zero length have no direction and are
    skipped rather than normalised to NaN. Examples are the average of two
    opposite bond vectors and the vector between coincident atoms.

    ---------------------------------------------------------------------------
    Arguments:
        coords (np.ndarray): Atomic coordinates, shape = (n_atoms, 3)

        max_triple_dist (float): A triple (i, j, k) only contributes axes if
                                 both |r_j - r_i| and |r_k - r_i| are below
                                 this distance

        sim_axis_tol (float): Tolerance used to discard axes that repeat (or
                              invert) an earlier one

    Returns:
        (list(np.ndarray)): Unique unit vectors
    """
    # Copy to float so integer input cannot break the normalisation below
    coords = np.asarray(coords, dtype=float)
    n_atoms = len(coords)

    # Vectors shorter than this have no meaningful direction
    zero_tol = 1e-6
    possible_axes = []

    def add_if_defined(vec):
        """Append vec scaled to unit length, unless it is ~zero"""
        norm = np.linalg.norm(vec)
        if norm > zero_tol:
            possible_axes.append(vec / norm)

    for i in range(n_atoms):
        for j in range(n_atoms):
            if i > j:  # For the unique pairs add the i–j vector
                add_if_defined(coords[j] - coords[i])

            if i == j:
                continue

            for k in range(n_atoms):
                # Triple must not have any of the same atoms
                if k == i or k == j:
                    continue

                vec1 = coords[j] - coords[i]
                vec2 = coords[k] - coords[i]
                if not all(
                    np.linalg.norm(vec) < max_triple_dist
                    for vec in (vec1, vec2)
                ):
                    continue

                add_if_defined((vec1 + vec2) / 2.0)

                perp_vec = np.cross(vec1, vec2)
                if np.linalg.norm(perp_vec) <= zero_tol:
                    # i, j and k are collinear, so the cross product is
                    # undefined. Take a deterministic perpendicular to the
                    # common line: cross it with the Cartesian unit vector it
                    # is least aligned with, which is well-conditioned.
                    if np.linalg.norm(vec1) >= np.linalg.norm(vec2):
                        line = vec1
                    else:
                        line = vec2
                    least_aligned = np.eye(3)[int(np.argmin(np.abs(line)))]
                    perp_vec = np.cross(line, least_aligned)

                add_if_defined(perp_vec)

    return strip_identical_and_inv_axes(possible_axes, sim_axis_tol)


def is_same_under_n_fold(
    pcoords, axis, n, m=1, tol=0.25, excluded_pcoords=None
):
    """
    Does applying an n-fold rotation about an axis generate the same structure
    back again?

    Each atom type (leading dimension of pcoords) is checked separately. Every
    rotated atom must lie within the tolerance of an original atom of the same
    type.

    If excluded_pcoords is given, a rotation that reproduces a structure
    already on that list is reported as False, so a rotation generated twice
    (e.g. C4^2 and C2 about the same axis) is only counted once.

    ---------------------------------------------------------------------------
    Arguments:
        pcoords (np.ndarray): shape = (n_unique_atom_types, n_atoms, 3)

        axis (np.ndarray): shape = (3,). Used as given; it is not validated
                           here, so callers should pass a finite, non-zero
                           vector

        n (int): n-fold of this rotation

        m (int): Apply this n-fold rotation m times

        tol (float): Distance tolerance (Å)

        excluded_pcoords (list | None): Rotated pcoords tensors generated by
                                        earlier rotations. Mutated: a new
                                        matching structure is appended

    Returns:
        (bool):
    """
    n_unique, _, _ = pcoords.shape
    if n_unique == 0:
        # No atoms to compare; kept as False for backwards compatibility
        return False

    rot_mat = get_rot_mat_euler(axis, theta=(2.0 * np.pi * m / n))
    rotated_coords = np.array(pcoords, dtype=float, copy=True)

    # Atom types whose atoms all sit (within tol) at the rotation centre are
    # unchanged by every rotation, so they cannot tell two rotations apart
    invariant = [False] * n_unique

    for i in range(n_unique):
        rotated_coords[i] = rot_mat.dot(pcoords[i].T).T
        dist_mat = distance_matrix(pcoords[i], rotated_coords[i])

        # If all elements are identical then carry on with the next element
        if np.linalg.norm(dist_mat) < tol:
            invariant[i] = True
            continue

        # If the RMS between the closest pairwise distance for each atom is
        # above the threshold then these structures are not the same
        if np.linalg.norm(np.min(dist_mat, axis=1)) > tol:
            return False

    if excluded_pcoords is None:
        return True

    informative = [i for i in range(n_unique) if not invariant[i]]

    def already_generated(prev):
        """Does one earlier structure match this one for every informative
        atom type?"""
        return all(
            np.linalg.norm(rotated_coords[i] - prev[i]) < tol
            for i in informative
        )

    # This permutation has already been found - return False even though
    # it's the same, because there is an excluded list
    if any(already_generated(prev) for prev in excluded_pcoords):
        return False

    # Add to a list of structures that have already been generated by rotations
    excluded_pcoords.append(rotated_coords)
    return True


def cn_and_axes(species, pcoords, max_n, dist_tol):
    """
    Group candidate rotation axes of a species by the Cn rotations that map
    the structure onto itself.

    ---------------------------------------------------------------------------
    Arguments:
        species (autode.species.Species): candidate axes are generated from
                                          its coordinates

        pcoords (np.ndarray): per-atom-type coordinates,
                              shape = (n_unique_atom_types, n_atoms, 3)

        max_n (int): largest n-fold rotation tested (n starts at 2)

        dist_tol (float): distance tolerance passed to is_same_under_n_fold

    Returns:
        (dict(int: list(np.ndarray))): for each n in 2..max_n, the candidate
                                       axes about which a single Cn rotation
                                       reproduces the structure
    """
    candidate_axes = get_possible_axes(coords=species.coordinates)

    # The minimum n-fold rotation is 2; every n gets a (possibly empty) list
    return {
        n: [
            axis
            for axis in candidate_axes
            if is_same_under_n_fold(pcoords, axis, n=n, tol=dist_tol)
        ]
        for n in range(2, max_n + 1)
    }


def create_pcoords(species):
    """
    Return a tensor where the first dimension is the size of the number of
    unique atom types in a molecule, the second, the atoms of that type
    and the third the number of dimensions in the coordinate space (3).

    Slice i holds the coordinates of the atoms of the i-th atom type at their
    own atom index and zeros elsewhere, so every atom appears in exactly one
    slice. Atom types are ordered by first appearance in species.atoms, which
    makes the ordering reproducible between runs.

    ---------------------------------------------------------------------------
    Arguments:
        species (autode.species.Species): provides .atoms (each with .label
                                          and .coord) and .n_atoms

    Returns:
        (np.ndarray): shape = (n_unique_atom_types, n_atoms, 3)
    """
    # dict.fromkeys keeps first-appearance order. Iterating a set of str does
    # not give a stable order, because str hashes are randomised per process.
    symbols = dict.fromkeys(atom.label for atom in species.atoms)
    symbol_idx = {symbol: idx for idx, symbol in enumerate(symbols)}

    pcoords = np.zeros(shape=(len(symbol_idx), species.n_atoms, 3))

    # Single pass: place each atom in the slice of its own type
    for j, atom in enumerate(species.atoms):
        pcoords[symbol_idx[atom.label], j, :] = atom.coord

    return pcoords


def symmetry_number(species, max_n_fold_rot_searched=6, dist_tol=0.25):
    """
    Calculate the symmetry number of a molecule. See:
    Theor Chem Account (2007) 118:813–826. 10.1007/s00214-007-0328-0

    Note: the species is translated in place so that its centre of mass is at
    the origin.

    Linear species get 2 if a C2 rotation perpendicular to the molecular axis
    reproduces the structure (D∞h), otherwise 1 (C∞v). For other species,
    C_n^m rotations (m = 1 … n-1) are applied about every axis found. Each
    rotation that reproduces the structure, and yields a structure not
    already generated, adds 1 to the count (which starts at 1 for E).

    ---------------------------------------------------------------------------
    Arguments:
        species (autode.atoms.Species):

    Keyword Arguments:
        max_n_fold_rot_searched (int): Largest n for which Cn rotations are
                                       tried

        dist_tol (float): Distance tolerance (Å)

    Returns:
        (int):
    """
    species.translate(vec=-species.com)
    pcoords = create_pcoords(species)

    if species.is_linear():
        return _linear_symmetry_number(pcoords, dist_tol)

    # Get the highest Cn-fold rotation axis
    cn_axes = cn_and_axes(species, pcoords, max_n_fold_rot_searched, dist_tol)

    # If there are no C2 or greater axes then this molecule is C1  → σ=1
    if all(len(axes) == 0 for axes in cn_axes.values()):
        return 1

    sigma_r = 1  # Already has E symmetry

    added_pcoords = []

    # For every possible axis apply C2, C3...C_n_max rotations
    for n, axes in cn_axes.items():
        for axis in axes:
            # Apply this rotation m times e.g. once for a C2 etc.
            for m in range(1, n):
                # If the structure is the same and has *not* been generated
                # by another rotation, increment the symmetry number by 1
                if is_same_under_n_fold(
                    pcoords,
                    axis,
                    n=n,
                    m=m,
                    tol=dist_tol,
                    excluded_pcoords=added_pcoords,
                ):
                    sigma_r += 1

    return sigma_r


def _linear_symmetry_number(pcoords, dist_tol):
    """
    Symmetry number of a linear molecule whose centre of mass is at the
    origin. Returns 2 if a C2 rotation about an axis perpendicular to the
    molecular axis reproduces the structure (D∞h), otherwise 1 (C∞v).

    ---------------------------------------------------------------------------
    Arguments:
        pcoords (np.ndarray): shape = (n_unique_atom_types, n_atoms, 3)

        dist_tol (float): Distance tolerance (Å)

    Returns:
        (int):
    """
    # Each atom occupies exactly one atom-type slice, so summing the slices
    # recovers the (n_atoms, 3) coordinates
    coords = np.sum(pcoords, axis=0)
    if coords.size == 0:
        return 1

    dists = np.linalg.norm(coords, axis=1)
    furthest = int(np.argmax(dists))
    if dists[furthest] < 1e-8:  # All atoms at the origin: no axis defined
        return 1

    # The centre of mass lies on the line, so any off-origin atom gives the
    # molecular axis direction
    mol_axis = coords[furthest] / dists[furthest]

    # Cross with the Cartesian unit vector least aligned with the molecular
    # axis to get a well-conditioned perpendicular direction
    least_aligned = np.eye(3)[int(np.argmin(np.abs(mol_axis)))]
    perp_axis = np.cross(mol_axis, least_aligned)
    perp_axis /= np.linalg.norm(perp_axis)

    if is_same_under_n_fold(pcoords, perp_axis, n=2, tol=dist_tol):
        return 2

    return 1
